{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "tensorflow_dynamic_shapes.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "FH3IRpYTta2v"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FH3IRpYTta2v"
      },
      "source": [
        "##### Copyright 2021 The IREE Authors"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mWGa71_Ct2ug",
        "cellView": "form"
      },
      "source": [
        "#@title Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
        "# See https://llvm.org/LICENSE.txt for license information.\n",
        "# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h5s6ncerSpc5"
      },
      "source": [
        "# Dynamic Shapes\n",
        "\n",
        "This notebook:\n",
        "\n",
        "1. Creates a TensorFlow program with dynamic shapes\n",
        "2. Imports that program into IREE's compiler\n",
        "3. Compiles the imported program to an IREE VM bytecode module\n",
        "4. Tests running the compiled VM module using IREE's runtime\n",
        "5. Downloads compilation artifacts for use with the native (C API) sample application"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s2bScbYkP6VZ",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "d202c522-b614-4785-f5f5-b95d2eb082c7"
      },
      "source": [
        "#@title General setup\n",
        "\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n",
        "os.makedirs(ARTIFACTS_DIR, exist_ok=True)\n",
        "print(f\"Using artifacts directory '{ARTIFACTS_DIR}'\")"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using artifacts directory '/tmp/iree/colab_artifacts'\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "%%capture\n",
        "!python -m pip install --upgrade tensorflow"
      ],
      "metadata": {
        "id": "y9KOsqosg6Ms"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Print version information for future notebook users to reference.\n",
        "print(\"TensorFlow version: \", tf.__version__)"
      ],
      "metadata": {
        "id": "SdCAvI3sqBO7",
        "outputId": "895381a9-b892-4855-f0c7-4e30d0e37988",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "TensorFlow version:  2.16.1\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dBHgjTjGPOJ7"
      },
      "source": [
        "## Create a program using TensorFlow and import it into IREE\n",
        "\n",
        "NOTE: As in other domains, giving a compiler more information allows it\n",
        "to generate more efficient code. As a general rule, the slowest-varying\n",
        "dimensions of program data, such as batch index or timestep, are safer to\n",
        "treat as dynamic than faster-varying dimensions such as image x/y/channel. See\n",
        "[this paper](https://arxiv.org/pdf/2006.03031.pdf) for a discussion of the\n",
        "challenges imposed by dynamic shapes and one project's approach to addressing\n",
        "them."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hwApbPstraWZ"
      },
      "source": [
        "#@title Define a sample TensorFlow module using dynamic shapes\n",
        "\n",
        "class DynamicShapesModule(tf.Module):\n",
        "  # reduce_sum_1d (dynamic input size, static output size)\n",
        "  #   e.g. [1, 2, 3] -> 6\n",
        "  @tf.function(input_signature=[tf.TensorSpec([None], tf.int32)])\n",
        "  def reduce_sum_1d(self, values):\n",
        "    return tf.math.reduce_sum(values)\n",
        "\n",
        "  # reduce_sum_2d (partially dynamic input size, static output size)\n",
        "  #   e.g. [[1, 2, 3], [10, 20, 30]] -> [11, 22, 33]\n",
        "  @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.int32)])\n",
        "  def reduce_sum_2d(self, values):\n",
        "    return tf.math.reduce_sum(values, 0)\n",
        "\n",
        "  # add_one (dynamic input size, dynamic output size)\n",
        "  #   e.g. [1, 2, 3] -> [2, 3, 4]\n",
        "  @tf.function(input_signature=[tf.TensorSpec([None], tf.int32)])\n",
        "  def add_one(self, values):\n",
        "    return tf.math.add(values, tf.constant(1, dtype=tf.int32))"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "k4aMPI2C7btB"
      },
      "source": [
        "%%capture\n",
        "!python -m pip install --pre iree-base-compiler iree-base-runtime iree-tools-tf -f https://iree.dev/pip-release-links.html"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Print version information for future notebook users to reference.\n",
        "!iree-compile --version"
      ],
      "metadata": {
        "id": "4T4NERQ2qC3V",
        "outputId": "1af9caa7-ef9d-4b0d-de6b-b4ee4e61c81a",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "IREE (https://iree.dev):\n",
            "  IREE compiler version 20240511.890 @ a3b7e12f1ae3b4d0da9cc5dfa5fb7865b178ec4b\n",
            "  LLVM version 19.0.0git\n",
            "  Optimized build\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3nSXZiZ_X8-P",
        "outputId": "7547248f-65f3-417b-8f87-416801a1cb83"
      },
      "source": [
        "#@title Import the TensorFlow program into IREE as MLIR\n",
        "\n",
        "from IPython.display import clear_output\n",
        "\n",
        "from iree.compiler import tf as tfc\n",
        "\n",
        "compiler_module = tfc.compile_module(\n",
        "    DynamicShapesModule(), import_only=True,\n",
        "    output_mlir_debuginfo=False)\n",
        "clear_output()  # Skip over TensorFlow's output.\n",
        "\n",
        "# Save the imported MLIR to disk.\n",
        "imported_mlirbc_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes.mlirbc\")\n",
        "with open(imported_mlirbc_path, \"wb\") as output_file:\n",
        "  output_file.write(compiler_module)\n",
        "print(f\"Wrote MLIR to path '{imported_mlirbc_path}'\")\n",
        "\n",
        "# Convert the MLIR bytecode to MLIR text to see how the compiler views this program.\n",
        "imported_mlir_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes.mlir\")\n",
        "!iree-ir-tool copy {imported_mlirbc_path} -o {imported_mlir_path}\n",
        "print(\"Dynamic Shapes MLIR:\")\n",
        "!cat {imported_mlir_path}"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Wrote MLIR to path '/tmp/iree/colab_artifacts/dynamic_shapes.mlirbc'\n",
            "Dynamic Shapes MLIR:\n",
            "module {\n",
            "  func.func @add_one(%arg0: tensor<?xi32>) -> tensor<?xi32> {\n",
            "    %c = stablehlo.constant dense<1> : tensor<i32>\n",
            "    %0 = shape.shape_of %arg0 : tensor<?xi32> -> tensor<1xindex>\n",
            "    %1 = stablehlo.dynamic_broadcast_in_dim %c, %0, dims = [] : (tensor<i32>, tensor<1xindex>) -> tensor<?xi32>\n",
            "    %2 = stablehlo.add %arg0, %1 : tensor<?xi32>\n",
            "    return %2 : tensor<?xi32>\n",
            "  }\n",
            "  func.func @reduce_sum_1d(%arg0: tensor<?xi32>) -> tensor<i32> {\n",
            "    %c = stablehlo.constant dense<0> : tensor<i32>\n",
            "    %0 = stablehlo.reduce(%arg0 init: %c) applies stablehlo.add across dimensions = [0] : (tensor<?xi32>, tensor<i32>) -> tensor<i32>\n",
            "    return %0 : tensor<i32>\n",
            "  }\n",
            "  func.func @reduce_sum_2d(%arg0: tensor<?x3xi32>) -> tensor<3xi32> {\n",
            "    %c = stablehlo.constant dense<0> : tensor<i32>\n",
            "    %0 = stablehlo.reduce(%arg0 init: %c) applies stablehlo.add across dimensions = [0] : (tensor<?x3xi32>, tensor<i32>) -> tensor<3xi32>\n",
            "    return %0 : tensor<3xi32>\n",
            "  }\n",
            "}"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WCiRV6KRh3iA"
      },
      "source": [
        "## Test the imported program\n",
        "\n",
        "_Note: you can stop after each step and use intermediate outputs with other tools outside of Colab._\n",
        "\n",
        "_See the [README](https://github.com/iree-org/iree/tree/main/samples/dynamic_shapes#instructions) for more details and example command line instructions._\n",
        "\n",
        "* _The \"imported MLIR\" can be used by IREE's generic compiler tools_\n",
        "* _The \"flatbuffer blob\" can be saved and used by runtime applications_\n",
        "\n",
        "_The specific point at which you switch from Python to native tools will depend on your project._"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GF0dzDsbaP2w",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "430bda38-44d2-454e-d652-1811f189faa9"
      },
      "source": [
        "#@title Compile the imported MLIR further into an IREE VM bytecode module\n",
        "\n",
        "from iree.compiler import compile_str\n",
        "\n",
        "# Note: we'll use the LLVM CPU backend since it has the best support\n",
        "# for dynamic shapes among our compiler targets.\n",
        "\n",
        "flatbuffer_blob = compile_str(compiler_module, target_backends=[\"llvm-cpu\"], input_type=\"stablehlo\")\n",
        "\n",
        "# Save the compiled program to disk.\n",
        "flatbuffer_path = os.path.join(ARTIFACTS_DIR, \"dynamic_shapes_cpu.vmfb\")\n",
        "with open(flatbuffer_path, \"wb\") as output_file:\n",
        "  output_file.write(flatbuffer_blob)\n",
        "print(f\"Wrote compiled program to path '{flatbuffer_path}'\")"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Wrote compiled program to path '/tmp/iree/colab_artifacts/dynamic_shapes_cpu.vmfb'\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h8cmF6nAfza0",
        "outputId": "099a0715-c194-4091-9338-ae17b1cfcb3a",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "source": [
        "#@title Test running the compiled VM module using IREE's runtime\n",
        "\n",
        "from iree import runtime as ireert\n",
        "\n",
        "config = ireert.Config(\"local-task\")\n",
        "ctx = ireert.SystemContext(config=config)\n",
        "vm_module = ireert.VmModule.from_flatbuffer(ctx.instance, flatbuffer_blob)\n",
        "ctx.add_vm_module(vm_module)"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "<ipython-input-10-e57c61828074>:7: UserWarning: Making copy of unaligned VmModule buffer. It is recommended to make this deterministic by calling `copy_buffer` to always make a copy or `mmap` to efficiently load from a file. This warning can be silenced by adding `warn_if_copy=False` to `from_buffer`\n",
            "  vm_module = ireert.VmModule.from_flatbuffer(ctx.instance, flatbuffer_blob)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "CQffg1iQatkb",
        "outputId": "ebf76b42-2fd6-4eda-dcd7-3049e431593a"
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "# Our @tf.functions are accessible by name on the module named 'module'\n",
        "dynamic_shapes_program = ctx.modules.module\n",
        "\n",
        "print(dynamic_shapes_program.reduce_sum_1d(np.array([1, 10, 100], dtype=np.int32)).to_host())\n",
        "print(dynamic_shapes_program.reduce_sum_2d(np.array([[1, 2, 3], [10, 20, 30]], dtype=np.int32)).to_host())\n",
        "print(dynamic_shapes_program.reduce_sum_2d(np.array([[1, 2, 3], [10, 20, 30], [100, 200, 300]], dtype=np.int32)).to_host())\n",
        "print(dynamic_shapes_program.add_one(np.array([1, 10, 100], dtype=np.int32)).to_host())"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "111\n",
            "[11 22 33]\n",
            "[111 222 333]\n",
            "[  2  11 101]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wCvwX1IEokm6"
      },
      "source": [
        "## Download compilation artifacts"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bUaNUkS2ohRj",
        "outputId": "2eb568ef-3345-4fcd-d42b-a2790c59d298",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 104
        }
      },
      "source": [
        "ARTIFACTS_ZIP = \"/tmp/dynamic_shapes_colab_artifacts.zip\"\n",
        "\n",
        "print(f\"Zipping '{ARTIFACTS_DIR}' to '{ARTIFACTS_ZIP}' for download...\")\n",
        "!cd {ARTIFACTS_DIR} && zip -r {ARTIFACTS_ZIP} .\n",
        "\n",
        "# Note: you can also download files using Colab's file explorer.\n",
        "try:\n",
        "  from google.colab import files\n",
        "  print(\"Downloading the artifacts zip file...\")\n",
        "  files.download(ARTIFACTS_ZIP)\n",
        "except ImportError:\n",
        "  print(\"Missing google.colab Python package; can't download files\")"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Zipping '/tmp/iree/colab_artifacts' to '/tmp/dynamic_shapes_colab_artifacts.zip' for download...\n",
            "  adding: dynamic_shapes.mlirbc (deflated 32%)\n",
            "  adding: dynamic_shapes_cpu.vmfb (deflated 64%)\n",
            "  adding: dynamic_shapes.mlir (deflated 70%)\n",
            "Downloading the artifacts zip file...\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "download(\"download_9a77ad78-5868-4873-a8c8-a3393e61fa5d\", \"dynamic_shapes_colab_artifacts.zip\", 6179)"
            ]
          },
          "metadata": {}
        }
      ]
    }
  ]
}
